#pragma once

#include <cerrno>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <string>

#include <unistd.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>

// Casts a pointer to a concrete socket address (e.g. sockaddr_in*) to the generic
// struct sockaddr* expected by bind()/accept()/connect(). The argument is wrapped in
// its own parentheses so an expression such as CONVERT(base + off) casts the whole
// expression rather than casting `base` and then doing sockaddr-sized arithmetic.
#define CONVERT(addrptr) ((struct sockaddr *)(addrptr))

namespace Net_Work
{
    // Backlog passed to listen() when no explicit value is supplied.
    constexpr int default_backlog = 5;
    // Descriptor value meaning "no socket yet"; TcpSocket uses it as its default fd.
    constexpr int defaultsock = -1;

    // Exit codes passed to exit() by the *OrDie socket helpers.
    enum class SocketErr
    {
        SocketError = 1, // socket() failed
        BindError = 2,   // bind() failed
        ListenError = 3  // listen() failed
    };

    /// Abstract socket interface. Concrete transports implement the primitive
    /// operations; the Build*Method helpers below compose them into common setups.
    class Socket
    {
    public:
        Socket() = default;

        // Virtual: implementations (e.g. TcpSocket::AcceptConnection) hand out
        // heap-allocated derived objects through Socket*, and deleting those through
        // a base pointer is undefined behavior unless the destructor is virtual.
        virtual ~Socket() = default;

        /// Creates the underlying socket; implementations terminate the process on failure.
        virtual void CreaterSocketOrDie() = 0;
        /// Binds to `port`; implementations terminate the process on failure.
        virtual void BindSocketOrDie(uint16_t port) = 0;
        /// Starts listening; implementations terminate the process on failure.
        virtual void ListenSocketOrDie(int backlog = default_backlog) = 0;
        /// Accepts one connection, reporting the peer address through the out-params.
        /// @return a newly allocated Socket for the connection, or nullptr on failure.
        virtual Socket *AcceptConnection(std::string *peerip, uint16_t *peerport) = 0;
        /// @return true if the connection to serverip:serverport was established.
        virtual bool ConnectServer(std::string &serverip, uint16_t serverport) = 0;
        virtual int GetSocket() = 0;
        virtual void SetSockFd(int sockfd) = 0;
        virtual void CloseSocket() = 0;

    public:
        /// Create -> bind(port) -> listen(backlog). Any failing step terminates the process.
        void BuildListenSocketMethod(uint16_t port, int backlog = default_backlog)
        {
            CreaterSocketOrDie();
            BindSocketOrDie(port);
            ListenSocketOrDie(backlog);
        }

        /// Create a socket and connect it to serverip:serverport.
        /// (Name keeps its historical spelling for source compatibility.)
        /// @return true on a successful connect; previously this result was discarded.
        bool BulidConnectSocketMethod(std::string &serverip, uint16_t serverport)
        {
            CreaterSocketOrDie();
            return ConnectServer(serverip, serverport);
        }

        /// Adopt an already-open descriptor (e.g. one returned by accept()).
        void BuildNormalSocketMethod(int sockfd)
        {
            SetSockFd(sockfd);
        }
    };

    class TcpSocket : public Socket
    {
    public:
        TcpSocket(int sockfd = defaultsock)
            : _sockfd(sockfd)
        {
        }
        ~TcpSocket()
        {
        }
        void CreaterSocketOrDie() override
        {
            _sockfd = socket(AF_INET, SOCK_STREAM, 0);
            if (_sockfd < 0)
            {
                exit(static_cast<int>(SocketErr::SocketError));
            }
        }

        void BindSocketOrDie(uint16_t port) override
        {
            struct sockaddr_in local;
            memset(&local, 0, sizeof local);
            local.sin_family = AF_INET;
            local.sin_port = htons(port);
            local.sin_addr.s_addr = INADDR_ANY;

            if (::bind(_sockfd, CONVERT(&local), sizeof local) < 0)
            {
                exit(static_cast<int>(SocketErr::BindError));
            }
        }

        void ListenSocketOrDie(int backlog = default_backlog) override
        {
            if (::listen(_sockfd, backlog) != 0)
                exit(static_cast<int>(SocketErr::ListenError));
        }

        Socket *AcceptConnection(std::string *peerip, uint16_t *peerport) override
        {
            struct sockaddr_in peeraddr;
            socklen_t len = sizeof peeraddr;
            int newsockfd = ::accept(_sockfd, CONVERT(&peeraddr), &len);
            if (newsockfd < 0)
                return nullptr;
            *peerip = inet_ntoa(peeraddr.sin_addr);
            *peerport = ntohl(peeraddr.sin_port);
            Socket *s = new TcpSocket(newsockfd);
            return s;
        }

        bool ConnectServer(std::string &serverip, uint16_t serverport) override
        {
            struct sockaddr_in server;
            memset(&server, 0, sizeof(server));
            server.sin_family = AF_INET;
            server.sin_port = htons(serverport);
            server.sin_addr.s_addr = inet_addr(serverip.c_str());

            int n = ::connect(_sockfd, CONVERT(&server), sizeof(server));
            if (n == 0)
                return true;
            else
                return false;
        }

        int GetSocket() override
        {
            return _sockfd;
        }

        void SetSockFd(int sockfd) override
        {
            _sockfd = sockfd;
        }

        void CloseSocket()
        {
            if (_sockfd > 0)
                close(_sockfd);
        }

    private:
        int _sockfd;
    };
}